
llama : add llama_batch_ext #11875


Open · wants to merge 61 commits into base: master

Conversation

ngxson
Collaborator

@ngxson ngxson commented Feb 14, 2025

Ref comment: #11292 (comment)

Closes #10381

Migration patterns:

llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
// becomes:
llama_batch_ext_ptr batch = llama_batch_ext_ptr(llama_batch_ext_init(n_kv_max, 1));


common_batch_add(batch, tokens[i], pos, { 0 }, false);
// becomes:
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch.get(), tokens[i], pos, &seq_id, 1, false);


llama_decode(lctx, llama_batch_get_one(tokens.data(), std::min(tokens.size(), (size_t) params.n_batch)));
// becomes:
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0));
llama_decode_ext(lctx, batch.get());


llama_decode(ctx, batch);
// becomes:
llama_decode_ext(ctx, batch.get());
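Putting these patterns together, a minimal decode call under the new API could look roughly like the sketch below. This is illustrative only: it assumes ctx is a valid llama_context and tokens is the tokenized prompt (std::vector<llama_token>), and it uses only the calls shown above.

// illustrative sketch only - combines the migration patterns above
const llama_seq_id seq_id = 0;

llama_batch_ext_ptr batch(llama_batch_ext_init((int32_t) tokens.size(), /*n_seq_max*/ 1));

for (size_t i = 0; i < tokens.size(); ++i) {
    // request logits only for the last token of the prompt
    const bool want_logits = (i == tokens.size() - 1);
    llama_batch_ext_add_text(batch.get(), tokens[i], (llama_pos) i, &seq_id, 1, want_logits);
}

llama_decode_ext(ctx, batch.get()); // return value / error handling omitted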

Current status:

  • This PR contains a first proposal for a public API that hides llama_batch from the public interface --> to be discussed
  • Only llama-server works for now
  • TODO: the members of llama_batch can be migrated to C++ types

@ngxson
Collaborator Author

ngxson commented Feb 14, 2025

@ggerganov Would you mind having a look at this initial proposal? Thank you!

include/llama.h Outdated
Comment on lines 266 to 272
struct llama_batch_ext_token_info {
    llama_token    token;
    llama_pos      pos;
    int32_t        n_seq_id;
    llama_seq_id * seq_id;
    int8_t         logits;
};
Member

This might not be very future-proof. Mixed-modality batches would have tokens, embeddings and tensors mixed together in the same batch. So calling llama_batch_ext_get_token_info(batch, i); is not always well-defined because it might not be a token at position i.

Maybe we can postpone this "token_info" API. I think all usages in the examples that require reading back info from the batch can be implemented in the example code without relying on the API. This way we can focus on implementing only the API for creating batches and adding data to them. Later on, when we have a better idea of the implementation, we can add a helper API to get info back from the batches.

Collaborator Author

Yes I agree. Furthermore, this API requires doing a copy, so it won't be the best for performance. It's better to remove this API for now.

I think all usages in the examples that require reading back info from the batch can be implemented in the example code without relying on the API.

This kind of logic is currently being used inside llama-server; I'm not sure it appears in any other examples. I think I can make a thin wrapper for llama_batch_ext inside the example code. Feel free to tell me if you have a better idea.

Collaborator Author

This API was removed in 1d6ba97; a new server_batch wrapper was added to manage token logits placement in the batch.

@ngxson
Collaborator Author

ngxson commented Mar 1, 2025

OK so I've been able to apply this to various examples (not all of them). It would be nice if you could have a quick look, @ggerganov, before I migrate the rest.

One thing to note: the loop that checks over tokens in the batch (discussed in #11875 (comment)) is used by both server.cpp and embeddings.cpp, so my solution was to create a thin wrapper called common_batch (a rough sketch of the idea is shown below). It looks a bit messy for now, so I'm wondering whether in the future we could have a llama_get_embeddings_ext or something similar to make this easier.
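For reference, one possible shape of such a thin wrapper is sketched here. This is only an illustration of the idea; the member names and layout of the actual common_batch in this PR may differ.

#include <vector>

// illustrative sketch only - not necessarily the actual common_batch from this PR
struct common_batch {
    llama_batch_ext_ptr batch;

    // per-token bookkeeping that the example code needs when reading results back
    struct entry {
        llama_token  token;
        llama_pos    pos;
        llama_seq_id seq_id;
        int32_t      output_id; // index into the logits/embeddings output, or -1 if no output
    };
    std::vector<entry> entries;

    int32_t n_outputs = 0;

    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output) {
        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, output);
        entries.push_back({ token, pos, seq_id, output ? n_outputs++ : -1 });
    }
};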

@ggerganov
Member

The common_batch is ok for now.

Looks a bit messy for now, so I'm wondering if in the future we can have a llama_get_embeddings_ext or something that can make this easier.

It seems we rather need something to query the batch, no? How do you imagine llama_get_embeddings_ext to work?

I was thinking something like:

struct llama_batch_ext_part;

llama_batch_ext_part * part = llama_batch_ext_get_part(batch, i);
if (llama_batch_ext_part_is_token(part)) {
    llama_token id = llama_batch_ext_part_get_id(part);
    ... get token id, sequence id, etc. ...
}

But since I'm not 100% sure about all the details related to multi-modal batches yet, I think it is better to postpone this API for later and handle the batch information in the user code for now.

@ngxson
Collaborator Author

ngxson commented Mar 3, 2025

How do you imagine llama_get_embeddings_ext to work?

I don't have a clear idea yet, but I'm thinking as a developer using libllama in their program: whenever I add a token to the batch, in the case of a text token I need to know:

  • The token (token ID in case of text)
  • The pos
  • The seq_id

So when I retrieve the logits/embeddings, I would imagine that the get_embeddings function will have one of these 2 signatures:

  • get_embeddings(seq_id) ==> we already had llama_get_embeddings_seq
  • get_embeddings(seq_id, pos) ==> we currently need to read back the tokens from batch

It seems we rather need something to query the batch, no?

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

  • Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id) (a rough sketch of this option is shown after this list)
  • Or, explicitly have llama_batch_ext_set_output(...) that returns the output_id. That means the logits param would be removed from llama_batch_ext_add_text
  • (Edit) Or, another option: llama_batch_ext_add_text can return the output_id if logits is set to true
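To make the first option concrete, usage might look like the sketch below. Everything here is hypothetical; none of these llama_batch_ext_* signatures exist yet.

// hypothetical sketch of option 1 - llama_batch_ext_query does not exist yet
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, /*logits*/ true);
llama_decode_ext(ctx, batch.get());

// look up which output slot corresponds to (seq_id, pos) ...
const int32_t output_id = llama_batch_ext_query(batch.get(), seq_id, pos);
// ... then read the embeddings for that slot
const float * embd = llama_get_embeddings_ith(ctx, output_id);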

@ggerganov
Member

Yes we can, and this will be quite similar to my point above. I'm thinking about these 2 options:

Having something like llama_batch_ext_query(seq_id, pos) that returns the output_id of the token. This can then be used with llama_get_embeddings_ith(output_id)
Or, explicitly have llama_batch_ext_set_output(...) that returns the output_id. That means the logits param would be removed from llama_batch_ext_add_text
(Edit) Or, another option: llama_batch_ext_add_text can return the output_id if logits is set to true

Hm, yes. The llama_batch_ext_set_output() idea sounds good.

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

@ggerganov
Member

Now that #12181 has been merged, it should be a good time to get this merged too.

@ngxson
Collaborator Author

ngxson commented Mar 13, 2025

Yes thanks for the heads up, I'll focus on finishing this today & tomorrow

@ngxson
Collaborator Author

ngxson commented Mar 13, 2025

Btw, this makes me wonder if we should actually move the output buffers for logits and the embeddings to be owned by the llama_batch_ext (currently these buffers are owned by the llama_context and are shared by all batches)?

If the output logits and embeddings are staying as float * or std::vector<float>, then yes, I think it will be better to move them to llama_batch_ext (this can be done in a follow-up PR).

@ngxson
Collaborator Author

ngxson commented Mar 24, 2025

@ggerganov Could you take a look at this PR this week? Thanks!

Member

@ggerganov ggerganov left a comment

Sorry for the delay. I still have some concerns:

  • I am thinking that it would be a good idea to make the llama_batch_ext API treat tokens and embeddings in a very similar manner. Ultimately, I think we should be able to create batches that contain both tokens and embeddings. For example, the call:
    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
            int32_t n_tokens,
            int32_t n_seq_max);

might be better to define as:

    // either one of these - not sure which one yet
    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
            struct llama_model * model,
            int32_t n_seq_max);

    // this one will figure out `n_seq_max` from the context
    // maybe actually this one is the best
    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
            struct llama_context * ctx);

Passing n_tokens to the batch init call was necessary in the past in order to pre-allocate an array with enough size. But it is technically redundant information, because we can add new tokens and embeddings and dynamically resize the batch in libllama as needed. So I think there is no longer a need to provide this information.

  • I feel like the embeddings API should mirror the tokens API. So instead of:
    LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
                  const float * embd,
                       size_t   n_tokens,
                       size_t   n_embd,
                    llama_pos   pos0,
                 llama_seq_id   seq_id);

    LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);

we should probably have something like:

    LLAMA_API int32_t llama_batch_ext_add_embd(
        struct llama_batch_ext * batch,
                   const float * embd, // size would be inferred from the context
                     llama_pos * pos,  // multiple pos per embd
                        size_t   n_pos,
            const llama_seq_id * seq_ids,
                        size_t   n_seq_ids,
                          bool   output);

I think these changes would help us think about tokens and embeddings as something very similar.
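For illustration, the proposed llama_batch_ext_add_embd could be used along these lines. This is only a sketch against the proposed signature above; image_embd, n_img_tokens, n_embd and p0 are made-up names and nothing here is final.

// sketch against the proposed signature - not a final API
// image_embd points to n_img_tokens embedding vectors of size n_embd from the vision encoder
const llama_seq_id seq_id = 0;

for (int i = 0; i < n_img_tokens; ++i) {
    llama_pos pos = p0 + i; // a single position per embedding in this simple case
    llama_batch_ext_add_embd(
        batch.get(),
        image_embd + i * n_embd, // size is inferred from the context in the proposal
        &pos, /*n_pos*/ 1,
        &seq_id, /*n_seq_ids*/ 1,
        /*output*/ false);
}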

@ngxson
Collaborator Author

ngxson commented Mar 25, 2025

Nice, thanks for the comment. Yes, I agree with your points. For llama_batch_ext_init, it should take the llama_context * ctx because I think one batch should only be valid for a given context (in case we create multiple contexts from a single model). I'll work on this now.

@ngxson
Collaborator Author

ngxson commented Mar 25, 2025

Ok so I implemented the new llama_batch_ext_init and llama_batch_ext_add_embd in my last commit: d18a79e

For now, qwen2vl-cli is broken because the llama_batch_ext_add_text API only accepts a single position. We could implement multiple positions per token inside this API itself, but it's a little bit tricky. Let's take an example: with n_pos_per_token == 4, if we add tokens at positions 1, 2, 3, 4, 5, 6, we expect the pos array to be:

1111 2222 3333 4444 5555 6666

But in reality, cgraph only accepts:

123456 123456 123456 123456

One is the transposed version of the other, so it will be simple to do: before llama_decode/encode we can reorder the pos array so that it has the correct layout (a rough sketch is shown below).
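A rough sketch of that reordering, assuming the positions are collected token-major first (names here are illustrative):

#include <vector>

// reorder positions from token-major layout (all positions of token 0, then token 1, ...)
// to the dimension-major layout that the cgraph expects (all first positions, then all second, ...)
static std::vector<llama_pos> reorder_pos(const std::vector<llama_pos> & pos, int n_tokens, int n_pos_per_token) {
    std::vector<llama_pos> out(pos.size());
    for (int t = 0; t < n_tokens; ++t) {
        for (int d = 0; d < n_pos_per_token; ++d) {
            out[d * n_tokens + t] = pos[t * n_pos_per_token + d];
        }
    }
    return out;
}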

But before implementing this, I just want to check with you if my direction still looks ok.

// before:
struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
    return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
}

// after:
struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
}
Member

Yes, for now this is a good solution. Later, we will be resizing dynamically and not need the llama_n_batch().

@ggerganov
Member

But before implementing this, I just want to check with you if my direction still looks ok.

I think this looks good.

@ngxson
Collaborator Author

ngxson commented Mar 25, 2025

I implemented the fix and tested qwen2vl-cli, and it seems to work with text. I can't test with an image for now since that's also broken on master.

The command is: llama-qwen2vl-cli -m ../models/Qwen2-VL-7B-Instruct-Q4_K_M.gguf --mmproj ../models/mmproj-Qwen2-VL-7B-Instruct-f16.gguf --image ../models/bliss.png -p "what do you see?"


I'm a bit confused about this: I found this nice illustration on the Qwen2-VL model page on HF, which shows that Qwen only uses 3 positions per token. This is also confirmed by the config.json file. However, I'm not sure why we use up to 4 positions in llama.cpp.

[image: position-ID illustration from the Qwen2-VL model page on HF]

@ggerganov
Member

I'm a bit confused about this: I found this nice illustration on the Qwen2-VL model page on HF, which shows that Qwen only uses 3 positions per token. This is also confirmed by the config.json file. However, I'm not sure why we use up to 4 positions in llama.cpp.

cc @HimariO

@@ -1963,7 +1963,7 @@ struct server_context {
const int32_t n_batch = llama_n_batch(ctx);

// only a single seq_id per token is needed
Member

This comment is obsolete.

Collaborator Author

Yeah, I think I removed it in one of the commits above; we don't need n_batch anymore, so I removed this whole code block.

@HimariO
Contributor

HimariO commented Mar 27, 2025

I'm a bit confused about this: I found this nice illustration on the Qwen2-VL model page on HF, which shows that Qwen only uses 3 positions per token. This is also confirmed by the config.json file. However, I'm not sure why we use up to 4 positions in llama.cpp.

The fourth position ID is mainly for future-proofing, in case a newer model that takes 3D/depth input (like SpatialLM) is added. Currently, both Qwen2-VL and Qwen2.5-VL only use 3 position IDs per token.

@ggerganov
Member

@ngxson We discussed these changes with @slaren, and he raised a good point that the batch API does not need to explicitly pass token positions. These can be inferred from the KV cache.

Since not having to pass explicit token positions would simplify the batch API, it's a good idea to take this into account when redesigning it. So I will try to do some KV cache refactoring to support this. When I'm ready, I will come back to this PR and update it accordingly.

@ngxson
Collaborator Author

ngxson commented Apr 1, 2025

Thanks for the heads up! Yes, for text batches it would be nice if the positions could be inferred from the KV cache.

Please also note that, for multimodal batches, we may also need the N-dimensional position. For example, in the case of Qwen2VL we have a normal position plus an additional 2D position for each image token. I think what we can do is that the 2D coordinate positions can be given by the user, while the "normal" position can still be inferred from the KV cache.

So, for example, in this PR the API for Qwen would accept (2*n_tokens) positions.

@ggerganov
Member

I think the Qwen2VL image positions could be inferred from the token position:

for (int y = 0; y < ph; y++)
{
    for (int x = 0; x < pw; x++)
    {
        int i = y * pw + x;
        mrope_pos[i]                  = *st_pos_id;
        mrope_pos[i + img_tokens]     = *st_pos_id + y;
        mrope_pos[i + img_tokens * 2] = *st_pos_id + x;
        mrope_pos[i + img_tokens * 3] = 0;
    }
}
*st_pos_id += std::max(pw, ph);

So ideally, the user would not have to pass those either.

@ngxson
Collaborator Author

ngxson commented Apr 1, 2025

We still need to know the image size pw and ph in order to infer the 2D positions. Of course, my goal for the future is to hide all of this inside the library. This is the idea in this comment where I suggest having a struct llama_mm_embd that contains the image embedding tensor and the image size pw, ph (a rough sketch of the idea is below).

In the far future (or not that far?), struct llama_mm_embd could also be used to store everything that the text decoder wants to retrieve from the multimodal encoder - that could be a series of images in the case of video, or audio tokens in the case of audio input.
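Roughly, the idea would be something like the sketch below. This struct does not exist anywhere yet; it is only meant to illustrate the proposal.

// purely a sketch of the proposal - llama_mm_embd is not part of any API yet
struct llama_mm_embd {
    float * embd;     // embeddings produced by the multimodal encoder
    int32_t n_tokens; // number of embedding vectors
    int32_t pw;       // image width  (in patches / image tokens)
    int32_t ph;       // image height (in patches / image tokens)
    // later this could also carry video frames, audio tokens, etc.
};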

@ggerganov
Member

Yes, the multi-dim positions are complicating things. Not sure what is the best solution.

The problem is that many vision models nowadays also need to know the size of the image when passing the embeddings from the encoder to the decoder.

Maybe with Gemma demonstrating that this is not necessary, new models won't need these complications and we won't have to add support at all. Are there other models besides Qwen2VL that need 2D positions?

When passing these embeddings to decoder, we need to "wrap" each slice between some tokens, then also "wrap" each row of 3 slices between another token:

This seems like something for which the user code can implement the logic (similar to how bos/eos tokens are added). Which models use this pattern?

Anyway, the KV cache refactoring can be done before making a decision about how to handle the images, so we can re-discuss this after that.

@ngxson
Collaborator Author

ngxson commented Apr 1, 2025

Just to clarify before writing my response: there are 2 reasons why the image size is needed:

  • For 2D positional embeddings like Qwen's, where each token has 3 positions (temporal, X, Y)
  • For the slices layout --> currently used by a lot of models like MiniCPM-V, SmolVLM, llava 1.6. This technique uses normal positions and a normal causal mask, but adds special tokens to identify the rows/cols of slices

Maybe with Gemma demonstrating that this is not necessary, new models won't need these complications and we won't have to add support at all. Are there other models besides Qwen2VL that need 2D positions?

Hmm, yeah, that could be right. Because M-RoPE was invented by Qwen, I don't see anyone else using it for now. I'm not sure if other models will adopt it in the future.

But please note that gemma 3 does not use slices. This makes working with gemma 3 vision easy, but the current problem with gemma 3 is that the image size is fixed. For bigger images, they need to rely on a technique called "pan and zoom", which is essentially a prompting technique that allows the model to "ask" the runtime to zoom into the image and then rerun the generation. This is obviously very inefficient.

Models like SmolVLM, MiniCPM-V and Qwen (and maybe many others) are already using the slicing technique that I mentioned earlier, so we definitely need to support it in the API.

This seems like something for which the user code can implement the logic (similar to how bos/eos tokens are added). Which models use this pattern?

In fact, a better way to think about it is that this is the "chat template" for images. While it can be implemented in user code, I think it's better to make it transparent from the user's POV: this part is model-specific and even harder to get right than the normal chat template for text, so it is not something the user can easily debug.

In gemma3-cli, you can see that the <start_of_image> token is added from user code. But my goal for the vision API is to hide this behind an API.

@ngxson
Collaborator Author

ngxson commented Apr 4, 2025

@ggerganov I've been working on audio input and output recently and I think the API proposed by this PR pretty much corresponds to what I need (of course, except for the positions, which would be nicer to have hidden from user code).

Having this PR merged could save me some effort and, more importantly, unblock my research on the multimodal API, so I'm wondering if there is anything I could do on my side to accelerate this a bit more? Thank you!


Also, now that the position is hidden from user code, I think we should also somehow modify the llama_kv_* API to make sure that users don't accidentally leave a "hole" in the context. For example, if the KV cache has 10 tokens and they delete the tokens at positions [2, 5), there will be no API to fill in this "hole" unless we also remove all tokens in [5, inf) (see the illustration below).
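To illustrate the scenario (a sketch assuming the llama_kv_self_seq_rm(ctx, seq_id, p0, p1) signature, where negative bounds mean "unbounded"):

// KV cache for seq 0 currently holds tokens at positions 0..9
llama_kv_self_seq_rm(ctx, 0, 2, 5);  // removes positions [2, 5) -> a "hole"; positions 5..9 remain

// with positions inferred inside libllama, new tokens would be appended after position 9,
// so the hole can never be filled unless everything from position 5 onwards is removed too:
llama_kv_self_seq_rm(ctx, 0, 5, -1); // p1 < 0 -> no upper bound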

@ggerganov
Member

My plan for the next steps was to refactor llama_kv_cache_unified into 2 separate implementations - unified and recurrent (I started some initial work in #12695). This will simplify the logic in llama_kv_cache and allow us to design better token position / sequence tracking in libllama, which can then be used instead of manually passing positions with the batches. When this is ready, we can update this PR to not pass the positions explicitly.

Also, now that the position is hidden from user code, I think we should also somehow modify the llama_kv_* API to make sure that users don't accidentally leave a "hole" in the context. For example, if the KV cache has 10 tokens and they delete the tokens at positions [2, 5), there will be no API to fill in this "hole" unless we also remove all tokens in [5, inf).

The llama_kv_self_seq_ API would still require the user to keep track of the token positions / sequence lengths, but I'm not sure if this can be improved.

Labels: android, examples, python, server

Successfully merging this pull request may close these issues.

Refactor: Allow adding both tokens and embeddings to llama_batch
3 participants